Описание¶

Задачи¶

Испольуется три задачи:

  • DM -- двухальтернативный выбор
  • Romo -- сравнение двух сигнало через задержку
  • CtxDM -- DM с контекстом Вход состоит из одного контекстного входа, одного стимула, 6 входов, кодирующих задачи. Выход как и раньше состоит из трех частей: контекстный выход, выходы принятия решения.

Сеть¶

Сеть состоит из lif AdEx нейронов

Импорт всех необходимых библиотек¶

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from cgtasknet.net.lifadex import SNNlifadex
from cgtasknet.tasks.reduce import (
    CtxDMTaskParameters,
    DMTaskParameters,
    DMTaskRandomModParameters,
    MultyReduceTasks,
    RomoTaskParameters,
    RomoTaskRandomModParameters,
)
from norse.torch.functional.lif_adex import LIFAdExParameters
from tqdm import tqdm

Определяем устройство¶

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"{device=}")
device=device(type='cuda', index=0)

Функция построения входов и выходов¶

In [3]:
import os


def plot_results(inputs, target_outputs, outputs):
    if isinstance(inputs, torch.Tensor) and isinstance(target_outputs, torch.Tensor):
        inputs, t_outputs = (
            inputs.detach().cpu().numpy(),
            target_outputs.detach().cpu().numpy(),
        )
    for bath in range(min(batch_size, 20)):
        fig = plt.figure(figsize=(15, 3))
        ax1 = fig.add_subplot(141)
        plt.title("Inputs")
        plt.xlabel("$time, ms$")
        plt.ylabel("$Magnitude$")
        for i in range(3):
            plt.plot(inputs[:, bath, i].T, label=rf"$in_{i + 1}$")
        plt.legend()
        plt.tight_layout()

        ax2 = fig.add_subplot(142)
        plt.title("Task code (context)")
        plt.xticks(np.arange(1, len(tasks) + 1), sorted(tasks), rotation=90)
        plt.yticks([])
        for i in range(3, inputs.shape[-1]):
            plt.plot([i - 2] * 2, [0, inputs[0, bath, i]])
        plt.tight_layout()

        ax3 = fig.add_subplot(143)
        plt.title("Target output")
        plt.xlabel("$time, ms$")
        for i in range(t_outputs.shape[-1]):
            plt.plot(t_outputs[:, bath, i], label=rf"$out_{i + 1}$")
        plt.legend()
        plt.tight_layout()

        ax4 = fig.add_subplot(144)
        plt.title("Real output")
        plt.xlabel("$time, ms$")
        for i in range(outputs.shape[-1]):
            plt.plot(
                outputs.detach().cpu().numpy()[:, bath, i], label=rf"$out_{i + 1}$"
            )
        plt.legend()
        plt.tight_layout()
        if not os.path.exists("figures"):
            os.mkdir("figures")
        plt.savefig(f"figures{os.sep}network_outputs_{name}_batch_{bath}.pdf")
        plt.show()
        plt.close()

Определяем датасет¶

Датасет будет состоять из трех типов задач:

  • DM задача;
  • Romo задача;
  • CtxDM задача. Параметры для последней задачи аналогичны DM задаче

Параметры датасета:¶

In [4]:
batch_size = 100
number_of_epochs = 2000
number_of_tasks = 1
romo_parameters = RomoTaskRandomModParameters(
    romo=RomoTaskParameters(
        delay=0.1,
        positive_shift_delay_time=1.4,
        trial_time=0.1,
        positive_shift_trial_time=0.2,
    ),
)
dm_parameters = DMTaskRandomModParameters(
    dm=DMTaskParameters(trial_time=0.1, positive_shift_trial_time=0.8)
)
ctx_parameters = CtxDMTaskParameters(dm=dm_parameters.dm)

Датасет¶

In [5]:
sigma = 0.5
tasks = ["RomoTask1", "RomoTask2", "DMTask1", "DMTask2", "CtxDMTask1", "CtxDMTask2"]
task_dict = {
    tasks[0]: romo_parameters,
    tasks[1]: romo_parameters,
    tasks[2]: dm_parameters,
    tasks[3]: dm_parameters,
    tasks[4]: ctx_parameters,
    tasks[5]: ctx_parameters,
}
Task = MultyReduceTasks(
    tasks=task_dict, batch_size=batch_size, delay_between=0, enable_fixation_delay=True
)

print("Task parameters:")
for key in task_dict:
    print(f"{key}:\n{task_dict[key]}\n")

print(f"inputs/outputs: {Task.feature_and_act_size[0]}/{Task.feature_and_act_size[1]}")
Task parameters:
RomoTask1:
RomoTaskRandomModParameters(romo=RomoTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=(None, None), delay=0.1, negative_shift_trial_time=0, positive_shift_trial_time=0.2, negative_shift_delay_time=0, positive_shift_delay_time=1.4), n_mods=2)

RomoTask2:
RomoTaskRandomModParameters(romo=RomoTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=(None, None), delay=0.1, negative_shift_trial_time=0, positive_shift_trial_time=0.2, negative_shift_delay_time=0, positive_shift_delay_time=1.4), n_mods=2)

DMTask1:
DMTaskRandomModParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), n_mods=2)

DMTask2:
DMTaskRandomModParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), n_mods=2)

CtxDMTask1:
CtxDMTaskParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), context=None, value=(None, None))

CtxDMTask2:
CtxDMTaskParameters(dm=DMTaskParameters(dt=0.001, trial_time=0.1, answer_time=0.15, value=None, negative_shift_trial_time=0, positive_shift_trial_time=0.8), context=None, value=(None, None))

inputs/outputs: 9/3

Пример генерации датасета¶

In [6]:
inputs, t_outputs = Task.dataset(n_trials=1)
for bath in range(min(batch_size, 10)):
    fig = plt.figure(figsize=(15, 3))
    ax1 = fig.add_subplot(131)
    plt.title("Inputs")
    plt.xlabel("$time, ms$")
    plt.ylabel("$Magnitude$")
    for i in range(3):
        plt.plot(inputs[:, bath, i], label=rf"$in_{i + 1}$")
    plt.legend()
    plt.tight_layout()

    ax2 = fig.add_subplot(132)
    plt.title("Task code (context)")
    plt.xlabel("$time, ms$")
    for i in range(3, inputs.shape[-1]):
        plt.plot(inputs[:, bath, i], label=rf"$in_{i + 1}$")
    plt.legend()
    plt.tight_layout()

    ax3 = fig.add_subplot(133)
    plt.title("Target output")
    plt.xlabel("$time, ms$")
    for i in range(t_outputs.shape[-1]):
        plt.plot(t_outputs[:, bath, i], label=rf"$out_{i + 1}$")
    plt.legend()
    plt.tight_layout()
plt.show()
plt.close()
del inputs
del t_outputs

Инициализация сети и выгрузка на decive¶

In [7]:
feature_size, output_size = Task.feature_and_act_size
hidden_size = 450

neuron_parameters = LIFAdExParameters(
    v_th=torch.as_tensor(0.65),
    tau_ada_inv=0.5 + (6 - 0.5) * torch.rand(hidden_size).to(device),
    alpha=100,
    method="super",
    # rho_reset = torch.as_tensor(5)
)
model = SNNlifadex(
    feature_size,
    hidden_size,
    output_size,
    neuron_parameters=neuron_parameters,
    tau_filter_inv=500,
).to(device)

Критерий и функция ошибки¶

In [8]:
learning_rate = 1e-2


class RMSELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()

    def forward(self, yhat, y):
        return torch.sqrt(self.mse(yhat, y))


criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Генерация всех эпох¶

Если память не позволяет, то необходимо генерировать каждую эпоху в основном цикле обучения

In [9]:
if False:
    list_inputs = []
    list_t_outputs = []
    for i in tqdm(range(number_of_epochs)):
        temp_input, temp_t_output = Task.dataset()
        temp_input.astype(dtype=np.float16)
        temp_t_output.astype(dtype=np.float16)
        temp_input[:, :, :] += np.random.normal(0, sigma, size=temp_input.shape)
        list_inputs.append(temp_input)
        list_t_outputs.append(temp_t_output)

JIT numba generator (test)¶

TODO: Необходимо добавить в cgtasknet и вызывать оттуда

In [10]:
from numba import njit, prange
import time


@njit(cache=True, parallel=True)
def every_bath_generator(
    start_sigma: float,
    stop_sigma: float,
    times: int = 1,
    batches: int = 1,
    actions: int = 1,
):
    data = np.zeros((times, batches, actions))
    for i in prange(batches):
        data[:, i, :] = np.random.normal(
            0, np.random.uniform(start_sigma, stop_sigma), size=(times, actions)
        )
    return data
In [11]:
every_bath_generator(0, 0)
Out[11]:
array([[[0.]]])

Основной цикл обучения¶

In [12]:
from cgtasknet.instruments.instrument_accuracy_network import correct_answer
from cgtasknet.net.states import LIFAdExRefracInitState

name = f"Train_dm_and_romo_task_reduce_lif_adex_without_refrac_random_delay_long_a_alpha_{neuron_parameters.alpha}_N_{hidden_size}"
init_state = LIFAdExRefracInitState(batch_size, hidden_size, device=device)
running_loss = 0
for i in tqdm(range(2000)):

    inputs, target_outputs = Task.dataset()
    inputs[:, :, :3] += every_bath_generator(
        0, sigma, inputs.shape[0], inputs.shape[1], 3
    )
    inputs = torch.from_numpy(inputs).type(torch.float).to(device)
    target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
    optimizer.zero_grad()

    # forward + backward + optimize
    outputs, _ = model(inputs)

    loss = criterion(outputs, target_outputs)
    loss.backward()
    optimizer.step()

    # print statistics
    running_loss += loss.item()
    if i % 10 == 9:
        with open("log_multy.txt", "a") as f:
            f.write("epoch: {:d} loss: {:0.5f}\n".format(i + 1, running_loss / 10))
        running_loss = 0.0
        with torch.no_grad():
            torch.save(
                model.state_dict(),
                name,
            )
    if i % 10 == 9:

        result = 0
        for j in range(10):
            try:
                del inputs
            except:
                pass
            try:
                del target_outputs
            except:
                pass
            try:
                del outputs
            except:
                pass
            torch.cuda.empty_cache()
            inputs, target_outputs = Task.dataset(1, delay_between=0)
            inputs += np.random.normal(0, 0.01, size=inputs.shape)
            inputs = torch.from_numpy(inputs).type(torch.float).to(device)
            target_outputs = (
                torch.from_numpy(target_outputs).type(torch.float).to(device)
            )
            outputs = model(inputs)[0]
            answers = correct_answer(
                outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
            )
            result += torch.sum(answers).item()

        accuracy = result / batch_size / 10 * 100
        with open("accuracy_multy.txt", "a") as f:
            f.write(f"ecpoch = {i}; correct/all = {accuracy}\n")
    try:
        del inputs
    except:
        pass
    try:
        del target_outputs
    except:
        pass
    try:
        del outputs
    except:
        pass
    torch.cuda.empty_cache()
print("Finished Training")
100%|██████████| 2000/2000 [3:28:05<00:00,  6.24s/it]  
Finished Training

Тестовая функция¶

In [13]:
def test_network(test_sigma:float, number_of_trials: int = 100, plot_data:bool=True):
    result = 0
    for j in tqdm(range(number_of_trials)):
        try:
            del inputs
        except:
            pass
        try:
            del target_outputs
        except:
            pass
        try:
            del outputs
        except:
            pass
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += every_bath_generator(0, test_sigma, inputs.shape[0], inputs.shape[1], 3)
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()

    accuracy = result / batch_size / number_of_trials * 100
    if plot_data:
        plot_results(inputs, target_outputs, outputs)
    try:
        del inputs
    except:
        pass
    try:
        del target_outputs
    except:
        pass
    try:
        del outputs
    except:
        pass
    torch.cuda.empty_cache()
    return accuracy

Тестирование¶

np.random.normal(0, 0.01, size=(inputs.shape))

In [23]:
accuracy = test_network(0.01, 100)
print(accuracy)
100%|██████████| 100/100 [03:26<00:00,  2.06s/it]
96.03

Тестирование¶

np.random.normal(0, 0.05, size=(inputs.shape))

In [15]:
accuracy = test_network(0.05)
print(accuracy)
100%|██████████| 100/100 [03:20<00:00,  2.00s/it]
95.37

Тестирование¶

np.random.normal(0, 0.1, size=(inputs.shape))

In [16]:
accuracy = test_network(0.1)
print(accuracy)
100%|██████████| 100/100 [03:21<00:00,  2.01s/it]
95.38

Тестирование¶

np.random.normal(0, 0.5, size=(inputs.shape))

In [17]:
accuracy = test_network(0.5)
print(accuracy)
100%|██████████| 100/100 [03:19<00:00,  1.99s/it]
94.15
In [18]:
result = 0
for j in tqdm(range(1)):
    try:
        del inputs
    except:
        pass
    try:
        del target_outputs
    except:
        pass
    try:
        del outputs
    except:
        pass
    torch.cuda.empty_cache()
    inputs, target_outputs = Task.dataset(1, delay_between=0)
    inputs[:, :, :3] += np.random.normal(0, 0.5, size=inputs[:, :, :3].shape)
    inputs = torch.from_numpy(inputs).type(torch.float).to(device)
    target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
    outputs = model(inputs)[0]
    answers = correct_answer(
        outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
    )
    result += torch.sum(answers).item()

accuracy = result / batch_size / 100 * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
try:
    del inputs
except:
    pass
try:
    del target_outputs
except:
    pass
try:
    del outputs
except:
    pass
torch.cuda.empty_cache()
100%|██████████| 1/1 [00:02<00:00,  2.26s/it]
0.89
In [19]:
result = 0
for j in tqdm(range(1)):
    try:
        del inputs
    except:
        pass
    try:
        del target_outputs
    except:
        pass
    try:
        del outputs
    except:
        pass
    torch.cuda.empty_cache()
    inputs, target_outputs = Task.dataset(1, delay_between=0)
    inputs[:, :, :3] += np.random.normal(0, 0.7, size=inputs[:, :, :3].shape)
    inputs = torch.from_numpy(inputs).type(torch.float).to(device)
    target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
    outputs = model(inputs)[0]
    answers = correct_answer(
        outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
    )
    result += torch.sum(answers).item()

accuracy = result / batch_size / 100 * 100
print(accuracy)
plot_results(inputs, target_outputs, outputs)
try:
    del inputs
except:
    pass
try:
    del target_outputs
except:
    pass
try:
    del outputs
except:
    pass
torch.cuda.empty_cache()
100%|██████████| 1/1 [00:02<00:00,  2.35s/it]
0.83
In [24]:
inputs = 0
outputs = 0
In [29]:
tau_ada_inv_distrib = neuron_parameters.tau_ada_inv.cpu().numpy()
np.save(f"tau_ada_inv_alpha={neuron_parameters.alpha}", tau_ada_inv_distrib)
In [38]:
lines = []
with open("accuracy_multy.txt", "r") as f:
    while line := f.readline():

        lines.append(float(line.split("=")[2]))
In [39]:
plt.figure(figsize=(8, 5))
plt.plot([*range(9, 2000, 10)], lines, ".", linestyle="--", markersize=5)
plt.ylabel(r"Accuracy%")
plt.xlabel(r"Epochs")
Out[39]:
Text(0.5, 0, 'Epochs')
In [41]:
start_sigma = 0
stop_sigma = 2
step_sigma = 0.05
sigma_array = np.arange(start_sigma, stop_sigma, step_sigma)
for test_sigma in tqdm(sigma_array):
    result = 0
    for j in range(20):
        try:
            del inputs
        except:
            pass
        try:
            del target_outputs
        except:
            pass
        try:
            del outputs
        except:
            pass
        torch.cuda.empty_cache()
        inputs, target_outputs = Task.dataset(1, delay_between=0)
        inputs[:, :, :3] += np.random.normal(0, test_sigma, size=inputs[:, :, :3].shape)
        inputs = torch.from_numpy(inputs).type(torch.float).to(device)
        target_outputs = torch.from_numpy(target_outputs).type(torch.float).to(device)
        outputs = model(inputs)[0]
        answers = correct_answer(
            outputs[:, :, 1:], target_outputs[:, :, 1:], target_outputs[:, :, 0]
        )
        result += torch.sum(answers).item()
    accuracy = result / batch_size / 20 * 100
    with open("accuracy_vs_noise.txt", "a") as f:
        f.write(f"noise={test_sigma}:accuracy={accuracy}\n")
100%|██████████| 40/40 [28:18<00:00, 42.46s/it]
In [42]:
import matplotlib.patches as patches

plt.style.use("ggplot")


def parser(line_text: str) -> tuple:
    """
    Function parses text in form:
    ```v_name_1=v1:v_name_2:v2```
    and return (v1, v2)
    :param line_text:
    :return: (v1, v2)
    """
    line_text = line_text.split(":")
    v1 = line_text[0].split("=")[1]
    v2 = line_text[1].split("=")[1]
    return float(v1), float(v2)


x, y = [], []
# with open('accuracy_vs_noise.txt', 'r') as f:
with open(
    r"A:\src\multy_task\notebooks\train\reduce\lif_adex\romo_dm_ctx\accuracy_vs_noise.txt",
    "r",
) as f:
    while line := f.readline():
        t_x, t_y = parser(line)
        x.append(t_x)
        y.append(t_y)

fig, ax = plt.subplots()
ax.plot(x, y, ".", linestyle="--")
# ax.plot([.5]*2, [50, 100])
ax.set_ylabel("Accuracy,%")
ax.set_xlabel(r"$\sigma$")
ax.add_patch(
    patches.Rectangle(
        (0, 50), 0.5, 50, edgecolor="grey", facecolor="grey", alpha=0.5, fill=True
    )
)

plt.show()
plt.close()